
[ROCm][GPTQ][Bugfix] Fix GPTQ GEMM kernel output zeroing race condition#30719

Merged
vllm-bot merged 10 commits into vllm-project:main from ROCm:fix/gptq-rocm
Dec 29, 2025

Conversation

@AndreasKaratzas
Collaborator

@AndreasKaratzas AndreasKaratzas commented Dec 15, 2025

Summary

Fixes a race condition in the GPTQ GEMM kernels that caused incorrect results when input_size > BLOCK_KN_SIZE (128).

Problem

The GPTQ GEMM kernels use multiple thread blocks along the k-dimension (gridDim.z > 1) that accumulate partial results via atomicAdd. The output tensor was being zeroed inside the kernel by blockIdx.z == 0:

```cpp
if (blockIdx.z == 0) {
  for (int m = 0; m < m_count; m++)
    *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
}
__syncthreads();  // Only syncs within block, not across blocks!
```

Since __syncthreads() only synchronizes threads within the same block, not across different blocks, this creates a race condition where:

  • Blocks with z > 0 may atomicAdd their results before block z=0 finishes zeroing
  • Block z=0 may overwrite results that other blocks have already added
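The failure mode can be replayed deterministically on the host. The following Python sketch is a simplified model of the kernel (not the actual CUDA code): it steps through one legal bad interleaving in which block z=1 lands its atomicAdd before block z=0 zeroes the output, so z=1's contribution is destroyed.

```python
# Hypothetical simplified model of two k-blocks racing on one output cell.
# With torch::empty(), the output starts with garbage; __syncthreads()
# gives no ordering guarantee ACROSS blocks, so any interleaving is legal.

partial = {0: 3.0, 1: 4.0}  # partial sums each k-block would atomicAdd
garbage = 123.0             # uninitialized value left by torch::empty()

out = garbage
out += partial[1]   # block z=1 performs its atomicAdd first
out = 0.0           # block z=0 now zeroes the cell, wiping z=1's work
out += partial[0]   # block z=0 adds its own partial sum

print(out)  # 3.0 -- should be 7.0; block z=1's contribution was lost
```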

This caused numerical errors up to 45x the expected values. The errors appeared only when input_size > 128 (which triggers multiple k-blocks), and were concentrated in specific rows at m-block boundaries.

Solution

  1. Pre-zero the output tensor using torch::zeros() instead of torch::empty()
  2. Remove the in-kernel zeroing logic from all GPTQ GEMM kernel variants
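A rough host-side analogue of why the fix works, again as a simplified Python model rather than the actual patch: once the output is zeroed before "launch", the per-block additions are commutative, so every interleaving of the k-blocks yields the same result.

```python
import itertools

# Hypothetical partial sums produced by each k-block (gridDim.z == 3 here).
partials = [3.0, 4.0, 5.0]

# Pre-zero the output on the host (torch::zeros instead of torch::empty),
# so the kernel only ever performs atomicAdd -- ordering no longer matters.
for order in itertools.permutations(range(len(partials))):
    out = 0.0  # host-side pre-zeroing
    for z in order:
        out += partials[z]  # each block's atomicAdd
    assert out == 12.0  # every interleaving agrees

print("all interleavings produce", out)
```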

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@mergify mergify bot added the rocm Related to AMD ROCm label Dec 15, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This PR correctly identifies and fixes a race condition in the GPTQ GEMM kernels by moving the output tensor zeroing from inside the CUDA kernels to the host side before the kernel launch. The changes are in the right direction, but the fix is incomplete. I've left a critical comment pointing out that the in-kernel zeroing logic also needs to be removed from several other kernel variants to fully resolve the bug.

I am having trouble creating individual review comments.

critical: csrc/quantization/gptq/q_gemm.cu (236-239)

This change correctly addresses the race condition for the 4-bit kernel. However, the fix is incomplete as the same problematic in-kernel zeroing logic exists in several other kernel variants. This is a critical issue because the bug will persist for other quantization bit-widths.

Please apply the same fix (i.e., remove the in-kernel zeroing) to the following kernels as well:

  • gemm_half_q_half_gptq_2bit_kernel (lines 375-378)
  • gemm_half_q_half_gptq_3bit_kernel (lines 497-500)
  • gemm_half_q_half_gptq_8bit_kernel (lines 626-629)
  • gemm_half_q_half_alt_4bit_kernel (lines 1227-1229)
  • gemm_half_q_half_alt_8bit_kernel (lines 1322-1324)

Additionally, it's better to remove this dead code entirely rather than commenting it out.

@tjtanaa tjtanaa requested a review from mgoin December 16, 2025 05:53
@AndreasKaratzas
Collaborator Author

> [...] the in-kernel zeroing logic also needs to be removed from several other kernel variants to fully resolve the bug.

Already done :)

@robertgshaw2-redhat
Collaborator

maybe we should just deprecate this kernel...

@AndreasKaratzas
Collaborator Author

> maybe we should just deprecate this kernel...

For now, I think we can merge this PR though, since it resolves the GPTQ test bug on ROCm.

Member

@mgoin mgoin left a comment


This makes sense and simplifies the kernel, LGTM! I do think we were planning to deprecate this kernel now that @jinzhen-lin added SM75 support for Marlin in #29901, but if ROCm needs this kernel we can keep it for now. I would highly recommend that the ROCm team investigate whether they can reuse the Marlin kernels.

@mgoin mgoin added bug Something isn't working ready ONLY add when PR is ready to merge/full CI is needed labels Dec 26, 2025
@mgoin mgoin enabled auto-merge (squash) December 26, 2025 17:34
@AndreasKaratzas
Collaborator Author

@mgoin There are some failures due to OSError: You are trying to access a gated repo. This looks like an HF issue. Can we merge this once the rest of the checks are done?

@AndreasKaratzas
Collaborator Author

cc @tjtanaa Can we merge this one too? The failing tests are known to be problematic.

@vllm-bot vllm-bot merged commit 3ecfdc3 into vllm-project:main Dec 29, 2025
88 of 93 checks passed
@AndreasKaratzas AndreasKaratzas deleted the fix/gptq-rocm branch December 29, 2025 18:06
yiliu30 pushed a commit to yiliu30/vllm-fork that referenced this pull request Dec 30, 2025
akh64bit pushed a commit to akh64bit/vllm that referenced this pull request Jan 16, 2026
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026